from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import os

class Inference:
    def __init__(self, model_name, gpu_id, model_path, base_model, device=None):
        """Bind the wrapper to one GPU and load the tokenizer and model onto it.

        Args:
            model_name: Label used only in log messages.
            gpu_id: CUDA device index. It is made the current CUDA device and
                the model is placed on ``cuda:{gpu_id}``.
            model_path: Path or identifier handed to ``from_pretrained``.
            base_model: Model family tag; other methods compare it to "qwen".
            device: Accepted but not used; the device is always derived from
                ``gpu_id``.

        Raises:
            Exception: Any error raised while selecting the device or loading
                the tokenizer/model is printed and then re-raised.
        """
        self.model_name = model_name
        self.gpu_id = gpu_id
        self.model_path = model_path
        self.base_model = base_model

        # The target device is the same either way; CUDA_VISIBLE_DEVICES only
        # decides whether the index is logged as a "local" or "global" ID.
        self.device = f"cuda:{gpu_id}"
        first_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")[0]
        id_kind = "local" if first_visible else "global"
        print(f"Using {id_kind} GPU ID {gpu_id} for {model_name}")

        print(f"Loading model {model_name} on device {self.device}")

        try:
            torch.cuda.set_device(gpu_id)
            print(f"Set current device to: cuda:{gpu_id}")
            print(f"Current CUDA device: {torch.cuda.current_device()}")

            with torch.cuda.device(gpu_id):
                # Left padding keeps prompts right-aligned for batched generation.
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    padding_side='left',
                )
                tok = self.tokenizer
                if tok.pad_token is None:
                    # Fall back to EOS so padded batches can be built.
                    tok.pad_token = tok.eos_token

                load_options = {
                    "torch_dtype": torch.float16,
                    "low_cpu_mem_usage": True,
                    "trust_remote_code": True,
                    "device_map": {'': self.device},
                }
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_path, **load_options
                )

                first_param = next(self.model.parameters())
                print(f"Model device: {first_param.device}")

        except Exception as e:
            print(f"Error initializing model on GPU {gpu_id}: {e}")
            raise
        
    class PromptDataset(Dataset):
        """Minimal map-style dataset that serves raw prompt strings by index.

        The tokenizer and ``max_length`` are only stored on the instance;
        nothing in this class tokenizes or truncates the prompts.
        """

        def __init__(self, prompts, tokenizer, max_length=512):
            self.max_length = max_length
            self.tokenizer = tokenizer
            self.prompts = prompts

        def __len__(self):
            """Return the number of prompts held."""
            total = len(self.prompts)
            return total

        def __getitem__(self, idx):
            """Return the prompt stored at position ``idx`` unchanged."""
            prompt = self.prompts[idx]
            return prompt
            
    def load_model(self):
        """Load the model and tokenizer if no model is currently resident.

        Does nothing when ``self.model`` is already set. Otherwise (for
        example after ``cleanup()`` has reset it to ``None``) the model is
        loaded from ``self.model_path`` onto ``self.device`` and switched to
        eval mode, and the tokenizer is reloaded alongside it.

        Raises:
            Exception: Whatever ``from_pretrained`` raises. In that case
                neither ``self.model`` nor ``self.tokenizer`` is modified, so
                a later call retries the full load.
        """
        if getattr(self, "model", None) is not None:
            return

        print(f"Loading model {self.model_name} on device {self.device}")
        model_path = self.model_path

        # NOTE(review): __init__ loads the weights in float16, while this
        # reload path uses bfloat16 plus local_files_only/use_safetensors --
        # confirm the difference is intended.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            local_files_only=True,
            trust_remote_code=True,
            use_safetensors=True,
            device_map={'': self.device}
        )

        # Use the same tokenizer settings as __init__: left padding is what
        # the batched generation methods rely on, and trust_remote_code
        # matches how the model itself is loaded.
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side='left'
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model.eval()

        # Publish both objects only after everything succeeded, so a failed
        # tokenizer load cannot leave a model paired with no tokenizer.
        self.model = model
        self.tokenizer = tokenizer
    
    def generate_response(
        self,
        instruction,
        max_new_tokens=512,
        use_chat_template=False
    ):
        """Generate single response"""
        self.load_model()
        try:
            with torch.no_grad():
                if use_chat_template:
                    if self.base_model == "qwen":
                        messages = [
                            {"role": "system", "content": "You are a helpful AI assistant."},
                            {"role": "user", "content": instruction}
                        ]
                        text = self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True
                        )
                        model_inputs = self.tokenizer(
                            [text], 
                            return_tensors="pt"
                        ).to(self.device)
                        
                        generated_ids = self.model.generate(
                            **model_inputs,
                            max_new_tokens=max_new_tokens
                        )
                        generated_ids = [
                            output_ids[len(input_ids):] 
                            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                        ]
                        response = self.tokenizer.batch_decode(
                            generated_ids, 
                            skip_special_tokens=True
                        )[0]
                    else:
                        chat = [
                            {"role": "user", "content": instruction},
                        ]
                        prompt = self.tokenizer.apply_chat_template(
                            chat, 
                            tokenize=False, 
                            add_generation_prompt=True
                        )
                        inputs = self.tokenizer.encode(
                            prompt, 
                            add_special_tokens=False, 
                            return_tensors="pt"
                        ).to(self.device)
                        
                        outputs = self.model.generate(
                            inputs,
                            max_new_tokens=max_new_tokens,
                            do_sample=False,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id
                        )
                        response = self.tokenizer.decode(
                            outputs[0][len(inputs[0]):], 
                            skip_special_tokens=True
                        ).strip()
                else:

                    full_prompt = f"Instruction: {instruction}\nResponse:"
                    inputs = self.tokenizer(
                        full_prompt,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).input_ids.to(self.device)
                    
                    outputs = self.model.generate(
                        inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=False,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
                    response = self.tokenizer.decode(
                        outputs[0][len(inputs[0]):], 
                        skip_special_tokens=True
                    ).strip()
                
                if "\nmodel\n" in response:
                    response = response.split("\nmodel\n")[-1].strip()
                    
                return response
                
        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"

    def cleanup(self):
        """Release the model and tokenizer once all generations are done.

        Does nothing when no model is loaded. Otherwise the model is moved to
        the CPU, both attributes are reset to ``None`` and the CUDA cache is
        emptied.
        """
        if getattr(self, 'model', None) is None:
            return

        self.model.to('cpu')
        # Drop the references first, then leave explicit None placeholders so
        # later "is None" checks keep working.
        for attr_name in ('model', 'tokenizer'):
            delattr(self, attr_name)
        self.model = self.tokenizer = None
        torch.cuda.empty_cache()

    def batch_generate_responses(
        self,
        instructions,
        batch_size=10,
        max_new_tokens=128,
        use_chat_template=False,
    ):
        """Generate one sampled response per instruction, in batches.

        Prompts are built from the instructions (with the tokenizer's chat
        template when ``use_chat_template`` is true, otherwise with a fixed
        "Here is a question: ..." wrapper), tokenized with padding and
        truncation to 512 tokens, and sampled with temperature 0.7,
        top_k 50 and top_p 0.95.

        Args:
            instructions: Sequence of instruction strings.
            batch_size: Requested batch size; it is capped by a heuristic
                based on total GPU memory and the model's
                ``max_position_embeddings``.
            max_new_tokens: Upper bound on generated tokens per prompt.
            use_chat_template: Whether to render prompts with the chat
                template.

        Returns:
            A list with one decoded response per instruction, in input order.
            If anything raises, the error is printed and a list holding the
            same "Error: ..." string once per instruction is returned.
        """
        try:
            responses = []
            # Heuristic cap: 70% of the device memory divided by
            # (max_position_embeddings * 2048), never below 1.
            max_batch_size = min(
                batch_size,
                max(1, int(torch.cuda.get_device_properties(self.gpu_id).total_memory * 0.7
                          / (self.model.config.max_position_embeddings * 2048)))
            )

            if use_chat_template:
                if self.base_model == "qwen":
                    batch_messages = [
                        [
                            {"role": "system", "content": "You are a helpful AI assistant."},
                            {"role": "user", "content": instruction}
                        ]
                        for instruction in instructions
                    ]
                    batch_prompts = [
                        self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True
                        )
                        for messages in batch_messages
                    ]
                else:
                    batch_prompts = [
                        self.tokenizer.apply_chat_template(
                            [{"role": "user", "content": "Here is a question: " + instruction + "\nPlease give me reasonable response as solution step by step."}],
                            tokenize=False,
                            add_generation_prompt=True
                        )
                        for instruction in instructions
                    ]
            else:
                batch_prompts = [
                    f"Here is a question: {instruction}\n Please give me reasonable response."
                    for instruction in instructions
                ]

            dataset = self.PromptDataset(batch_prompts, self.tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=max_batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True
            )

            for prompt_batch in tqdm(dataloader, desc="Processing batches"):
                with torch.no_grad(), torch.amp.autocast('cuda'):
                    model_inputs = self.tokenizer(
                        prompt_batch,
                        padding=True,
                        truncation=True,
                        max_length=512,
                        return_tensors="pt"
                    ).to(self.device)

                    # generate() returns every row as the full padded prompt
                    # followed by the new tokens, so the continuation starts
                    # at the padded prompt width for all rows. Slicing at the
                    # count of non-pad tokens instead would leak the tail of
                    # the prompt into the response of every padded row.
                    prompt_width = model_inputs.input_ids.shape[1]

                    if self.base_model == "qwen":
                        sequences = self.model.generate(
                            **model_inputs,
                            max_new_tokens=max_new_tokens,
                            do_sample=True,
                            temperature=0.7,
                            top_k=50,
                            top_p=0.95,
                        )
                        batch_responses = self.tokenizer.batch_decode(
                            [row[prompt_width:] for row in sequences],
                            skip_special_tokens=True
                        )
                    else:
                        sequences = self.model.generate(
                            input_ids=model_inputs.input_ids,
                            attention_mask=model_inputs.attention_mask,
                            max_new_tokens=max_new_tokens,
                            do_sample=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id,
                            use_cache=True,
                            num_beams=1,
                            early_stopping=False,
                            return_dict_in_generate=False,
                            temperature=0.7,
                            top_k=50,
                            top_p=0.95,
                        )
                        batch_responses = [
                            self.tokenizer.decode(
                                row[prompt_width:],
                                skip_special_tokens=True
                            ).strip()
                            for row in sequences
                        ]

                    # Drop any echoed turn header that precedes the reply.
                    batch_responses = [
                        response.split("\nmodel\n")[-1].strip() if "\nmodel\n" in response else response
                        for response in batch_responses
                    ]

                    responses.extend(batch_responses)

                    # Release per-batch tensors before the next batch.
                    del model_inputs, sequences
                    torch.cuda.empty_cache()

            return responses

        except Exception as e:
            print(f"Error in batch generation: {e}")
            return [f"Error: {str(e)}"] * len(instructions)
        
    def judge_batch_generate_responses(
        self,
        instructions,
        batch_size=10,
        max_new_tokens=128,
        use_chat_template=False
    ):
        """Generate one sampled judge response per instruction, in batches.

        With ``use_chat_template`` the instructions are rendered through the
        tokenizer's chat template (with a Qwen system message when
        ``base_model == "qwen"``); otherwise they are used verbatim as
        prompts. Prompts are tokenized with padding and truncation to 512
        tokens and sampled with temperature 0.7, top_k 50 and top_p 0.95.

        Args:
            instructions: Sequence of instruction strings.
            batch_size: Requested batch size; it is capped by a heuristic
                based on total GPU memory and the model's
                ``max_position_embeddings``.
            max_new_tokens: Upper bound on generated tokens per prompt.
            use_chat_template: Whether to render prompts with the chat
                template.

        Returns:
            A list with one decoded response per instruction, in input order.
            If a response contains "Assistant:", only the text after its last
            occurrence is kept. If anything raises, the error is printed and
            a list holding the same "Error: ..." string once per instruction
            is returned.
        """
        try:
            responses = []
            # Heuristic cap: 70% of the device memory divided by
            # (max_position_embeddings * 2048), never below 1.
            max_batch_size = min(
                batch_size,
                max(1, int(torch.cuda.get_device_properties(self.gpu_id).total_memory * 0.7
                          / (self.model.config.max_position_embeddings * 2048)))
            )

            if use_chat_template:
                if self.base_model == "qwen":
                    batch_messages = [
                        [
                            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
                            {"role": "user", "content": instruction}
                        ]
                        for instruction in instructions
                    ]
                    batch_prompts = [
                        self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True
                        )
                        for messages in batch_messages
                    ]
                else:
                    batch_prompts = [
                        self.tokenizer.apply_chat_template(
                            [{"role": "user", "content": instruction}],
                            tokenize=False,
                            add_generation_prompt=True
                        )
                        for instruction in instructions
                    ]
            else:
                batch_prompts = instructions

            dataset = self.PromptDataset(batch_prompts, self.tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=max_batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True
            )

            for prompt_batch in tqdm(dataloader, desc="Processing judge batches"):
                with torch.no_grad(), torch.amp.autocast('cuda'):
                    model_inputs = self.tokenizer(
                        prompt_batch,
                        padding=True,
                        truncation=True,
                        max_length=512,
                        return_tensors="pt"
                    ).to(self.device)

                    # generate() returns every row as the full padded prompt
                    # followed by the new tokens, so the continuation starts
                    # at the padded prompt width for all rows. Slicing at the
                    # count of non-pad tokens instead would leak the tail of
                    # the prompt into the response of every padded row.
                    prompt_width = model_inputs.input_ids.shape[1]

                    if self.base_model == "qwen":
                        sequences = self.model.generate(
                            **model_inputs,
                            max_new_tokens=max_new_tokens,
                            do_sample=True,
                            temperature=0.7,
                            top_k=50,
                            top_p=0.95,
                        )
                        batch_responses = self.tokenizer.batch_decode(
                            [row[prompt_width:] for row in sequences],
                            skip_special_tokens=True
                        )
                    else:
                        # Generation path for all other model families.
                        sequences = self.model.generate(
                            input_ids=model_inputs.input_ids,
                            attention_mask=model_inputs.attention_mask,
                            max_new_tokens=max_new_tokens,
                            do_sample=True,
                            pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.eos_token_id,
                            use_cache=True,
                            num_beams=1,
                            early_stopping=False,
                            return_dict_in_generate=False,
                            temperature=0.7,
                            top_k=50,
                            top_p=0.95,
                        )
                        batch_responses = [
                            self.tokenizer.decode(
                                row[prompt_width:],
                                skip_special_tokens=True
                            ).strip()
                            for row in sequences
                        ]

                    # Keep only the text after the last "Assistant:" marker.
                    batch_responses = [
                        response.split("Assistant:")[-1].strip() if "Assistant:" in response else response
                        for response in batch_responses
                    ]

                    responses.extend(batch_responses)

                    # Release per-batch tensors before the next batch.
                    del model_inputs, sequences
                    torch.cuda.empty_cache()

            return responses

        except Exception as e:
            print(f"Error in judge batch generation: {e}")
            return [f"Error: {str(e)}"] * len(instructions)

